refactor(sampler): consolidate sampling interface, part 1 #5312
for more information, see https://pre-commit.ci
Codecov Report
@@            Coverage Diff             @@
##           master    #5312      +/-   ##
==========================================
+ Coverage   83.33%   83.38%   +0.04%
==========================================
  Files         338      342       +4
  Lines       18678    18745      +67
==========================================
+ Hits        15565    15630      +65
- Misses       3113     3115       +2
torch_geometric/sampler/base.py
Outdated
    @abstractmethod
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
These are going to get painful to continue to support. Do we have "data to graphstore"? Can we assume that you will always convert to the new representation? Maybe that doesn't make sense...
Yeah, this is a great point. We do plan to consolidate behind `Tuple[FeatureStore, GraphStore]` in the near future, but probably in a separate PR.
GraphStore
LGTM.
torch_geometric/sampler/base.py
Outdated
# * edge: a tensor of the indices in the original graph; e.g. to be used to
#   obtain edge attributes.
class SamplerOutput(NamedTuple):
    node: Union[torch.Tensor, Dict[NodeType, torch.Tensor]]
We might want to add an optional `batch` attribute which assigns each node to an example. This will be necessary to integrate with the new `pyg-lib` implementation.
Added a TODO.
Done; although this may become more complicated once we try to simplify `metadata`. Can revert later if necessary.
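To make the `batch` suggestion above concrete, here is a minimal sketch of an output tuple with an optional per-node example assignment. It is not the PR's actual `SamplerOutput`: the class name `SamplerOutputSketch`, the helper `nodes_of_example`, and the reduced field set are assumptions for illustration, and plain Python lists stand in for `torch.Tensor` so the snippet has no dependencies.

```python
from typing import List, NamedTuple, Optional


class SamplerOutputSketch(NamedTuple):
    """Simplified, hypothetical stand-in for a homogeneous sampler output."""
    node: List[int]                    # sampled node indices in the original graph
    edge: List[int]                    # sampled edge indices in the original graph
    batch: Optional[List[int]] = None  # assumed: example index for each entry of `node`


def nodes_of_example(out: SamplerOutputSketch, example: int) -> List[int]:
    """Return the sampled nodes that were assigned to one example."""
    if out.batch is None:
        raise ValueError("this sampler output carries no 'batch' assignment")
    return [n for n, b in zip(out.node, out.batch) if b == example]


out = SamplerOutputSketch(node=[10, 42, 7, 13], edge=[3, 8], batch=[0, 0, 1, 1])
print(nodes_of_example(out, 1))  # [7, 13]
```

Keeping `batch` optional with a `None` default means samplers that do not track example membership can construct the output unchanged.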
…ric into remote_backend_2
… 2 (#5365)

This PR begins the effort to consolidate PyG's sampling interface in preparation for moving `sample(...)` behind the `GraphStore` interface. This effort is somewhat large in scope and will be broken into multiple PRs for ease of review. It resolves some TODOs from #5312, and introduces others that will be resolved in follow-up PRs.

The major change removes `LinkNeighborSampler`, consolidating it under the common `BaseSampler` interface with `NeighborSampler`. In doing so, it modifies the `BaseSampler` interface to support `sample_from_nodes` and `sample_from_edges`, since edge-based sampling often applies some transformations and then re-uses the same logic as node-based sampling. A sampler need not define both node-based and edge-based sampling, but if it doesn't, an appropriate error will be raised.

Resolved TODOs:
* `LinkNeighborSampler` and `NeighborSampler` interfaces are aligned
* No more special handling of `edge_type_to_str` in `filter_*`
* No more special handling of `perm_dict`, which removes the need for `edge_type_to_str` entirely
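The interface shape described in this commit message can be sketched roughly as follows. This is not PyG's actual `BaseSampler`: the class names, the list-based signatures, and the toy one-hop "sampling" are assumptions made for illustration. It only shows the two ideas from the text: edge-based sampling transforming its input and reusing node-based logic, and an error being raised when a sampler does not define one of the two methods.

```python
from typing import Dict, List, Tuple


class BaseSamplerSketch:
    """Hypothetical common interface; subclasses override what they support."""

    def sample_from_nodes(self, seeds: List[int]) -> List[int]:
        raise NotImplementedError(
            f"{type(self).__name__} does not support node-based sampling")

    def sample_from_edges(self, edges: List[Tuple[int, int]]) -> List[int]:
        raise NotImplementedError(
            f"{type(self).__name__} does not support edge-based sampling")


class ToyNeighborSampler(BaseSamplerSketch):
    """Toy sampler: returns the seeds plus all of their one-hop neighbors."""

    def __init__(self, adj: Dict[int, List[int]]):
        self.adj = adj  # node index -> list of neighbor indices

    def sample_from_nodes(self, seeds: List[int]) -> List[int]:
        nodes = list(seeds)
        for seed in seeds:
            nodes.extend(self.adj.get(seed, []))
        return list(dict.fromkeys(nodes))  # drop duplicates, keep first-seen order

    def sample_from_edges(self, edges: List[Tuple[int, int]]) -> List[int]:
        # Transform the edges into seed nodes, then reuse the node-based logic.
        seeds = [n for edge in edges for n in edge]
        return self.sample_from_nodes(seeds)


class NodeOnlySampler(BaseSamplerSketch):
    """Defines node-based sampling only; edge-based sampling raises."""

    def sample_from_nodes(self, seeds: List[int]) -> List[int]:
        return list(seeds)


sampler = ToyNeighborSampler({0: [1, 2], 1: [3], 2: [], 3: [0]})
print(sampler.sample_from_edges([(0, 1)]))  # [0, 1, 2, 3]
```

Calling `NodeOnlySampler().sample_from_edges(...)` raises `NotImplementedError` naming the class, which is one simple way to surface the "appropriate error" the text mentions.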
This PR begins the effort to consolidate PyG's sampling interface in preparation for moving `sample(...)` behind the `GraphStore` interface. This effort is somewhat large in scope and will be broken into multiple PRs for ease of review. There are many TODOs left in this PR; all will be resolved in forthcoming PRs (the code will become cleaner in short order).

This first change moves `NeighborSampler` behind a common base class with clearly defined inputs and outputs. Many more changes will be required before this base class defines a stable interface (e.g. `LinkNeighborSampler` has different semantics, `NeighborSampler.perm_dict` is used in `NeighborLoader`, etc.). But for now, it will act as a first step towards this goal.

It further splits the utilities in `loader/utils.py` into those that are required for sampling (moved into `sampler/utils.py`) and those that are required for filtering (remain in `loader/utils.py`).

Note about integration with `GraphStore`. Moving sampling behind the graph store requires defining a clear interface for sampling across PyG. Since a graph store should not inherently be coupled with the sampling method used (e.g. one can imagine using a `NeighborLoader` and `LinkNeighborLoader` on the same graph store), we must define this interface independently from the `GraphStore`. One can then imagine the `GraphStore` defining supported samplers for different loaders (along with the defaults, which sample in-memory), and those samplers are then seamlessly used in the loaders.
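One way to picture the decoupling described in the note is a graph store that optionally advertises samplers per loader, with an in-memory default used otherwise. Everything below is hypothetical: `GraphStoreSketch`, `supported_samplers`, `resolve_sampler`, and the function-style samplers are illustrative names, not part of PyG's `GraphStore` API.

```python
from typing import Callable, Dict, List

# For this sketch a sampler is just a function from seed node indices to
# sampled node indices; a real sampler would return a richer output.
Sampler = Callable[[List[int]], List[int]]


def in_memory_sampler(seeds: List[int]) -> List[int]:
    """Assumed default that samples in-memory (identity, to keep things short)."""
    return list(seeds)


class GraphStoreSketch:
    """Hypothetical graph store that may advertise its own samplers."""

    def supported_samplers(self) -> Dict[str, Sampler]:
        # Keyed by loader name; an empty dict means "use the defaults".
        return {}


class RemoteGraphStoreSketch(GraphStoreSketch):
    """Pretends to offer a backend-specific sampler for one loader only."""

    def supported_samplers(self) -> Dict[str, Sampler]:
        return {"NeighborLoader": lambda seeds: sorted(set(seeds))}


def resolve_sampler(store: GraphStoreSketch, loader_name: str,
                    default: Sampler = in_memory_sampler) -> Sampler:
    """Pick the store's sampler for this loader, else fall back to the default."""
    return store.supported_samplers().get(loader_name, default)


store = RemoteGraphStoreSketch()
print(resolve_sampler(store, "NeighborLoader")([3, 1, 3]))      # [1, 3]
print(resolve_sampler(store, "LinkNeighborLoader")([3, 1, 3]))  # [3, 1, 3]
```

The sampler interface lives outside the store in this sketch, so the same store can serve several loaders while only overriding the ones it has a specialized implementation for.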